{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Convert the checkpoint downloaded to `pytorch_checkpoints/release_model_simple/checkpoint_14.pth.tar` into the standard PyTorch ResNet-50 format."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
     "import torch\n",
     "import torchvision\n",
     "import torchvision.models.resnet as resnet\n",
     "import torch.nn as nn\n",
     "\n",
     "# Path to the released checkpoint; it must be downloaded before running this cell.\n",
     "model_path = 'pytorch_checkpoints/release_model_simple/checkpoint_14.pth.tar'\n",
     "# NOTE(review): torch.load unpickles the file, which can execute arbitrary code --\n",
     "# only point this at a checkpoint obtained from a trusted source.\n",
     "model_state = torch.load(model_path, map_location='cpu')['state_dict']\n",
     "\n",
     "# Freshly constructed torchvision ResNet-50 that receives the checkpoint weights.\n",
     "net = resnet.resnet50()\n",
     "net_state = net.state_dict()\n",
     "\n",
     "# Copy every 'encoderVideo' entry of the checkpoint into the ResNet-50 entry of the\n",
     "# same name (after stripping the 'module.encoderVideo.' prefix).\n",
     "for k in [k for k in model_state.keys() if 'encoderVideo' in k]:\n",
     "    kk = k.replace('module.encoderVideo.', '')\n",
     "    tmp = model_state[k]\n",
     "    # A 5-D checkpoint tensor paired with a 4-D ResNet tensor gets its dim 2 squeezed\n",
     "    # (presumably a singleton temporal axis of a 3-D conv kernel -- TODO confirm).\n",
     "    if net_state[kk].shape != tmp.shape and net_state[kk].dim() == 4 and tmp.dim() == 5:\n",
     "        tmp = tmp.squeeze(2)\n",
     "    # copy_ writes in place exactly like the former slice assignment, but also works\n",
     "    # for 0-dim entries (e.g. BatchNorm's num_batches_tracked), which cannot be sliced.\n",
     "    net_state[kk].copy_(tmp)\n",
     "\n",
     "net.load_state_dict(net_state)\n",
     "\n",
     "# Head appended after the truncated backbone: 1x1 conv (1024 -> 512 channels) + ReLU.\n",
     "afterconv1 = nn.Conv2d(1024, 512, kernel_size=1, bias=False)\n",
     "relu_layer = nn.ReLU(inplace=True)\n",
     "\n",
     "# squeeze(2) drops dim 2 of the checkpoint weight so it fits the Conv2d weight\n",
     "# (presumably a singleton temporal axis -- TODO confirm).\n",
     "afterconv1_weight = model_state['module.afterconv1.weight'].squeeze(2)\n",
     "# load_state_dict validates the key and the shape, instead of silently writing\n",
     "# through a temporary state_dict view.\n",
     "afterconv1.load_state_dict({'weight': afterconv1_weight})\n",
     "\n",
     "# Keep every child of the ResNet-50 except the last three, then append the head.\n",
     "layers = list(net.children())[:-3]\n",
     "layers.append(afterconv1)\n",
     "layers.append(relu_layer)\n",
     "\n",
     "net = nn.Sequential(*layers)\n",
     "\n",
     "# Sanity check on a random input. eval() keeps BatchNorm from overwriting the running\n",
     "# statistics that were just loaded (call net.train() again before any fine-tuning);\n",
     "# no_grad() avoids building an autograd graph for this throwaway forward pass.\n",
     "net.eval()\n",
     "inp = torch.randn(1, 3, 240, 240)\n",
     "with torch.no_grad():\n",
     "    out = net(inp)\n",
     "print(out.size())\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 2",
   "language": "python",
   "name": "python2"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
